// Copyright 2024 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package rpcserver

import (
	"context"
	"net"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/google/syzkaller/pkg/csource"
	"github.com/google/syzkaller/pkg/flatrpc"
	"github.com/google/syzkaller/pkg/fuzzer/queue"
	"github.com/google/syzkaller/pkg/mgrconfig"
	"github.com/google/syzkaller/pkg/rpcserver/mocks"
	"github.com/google/syzkaller/pkg/vminfo"
	"github.com/google/syzkaller/prog"
	"github.com/google/syzkaller/sys/targets"
	"golang.org/x/sync/errgroup"
)

func getTestDefaultCfg() mgrconfig.Config {
	return mgrconfig.Config{
		Type:    targets.Linux,
		Sandbox: "none",
		Derived: mgrconfig.Derived{
			TargetOS:     targets.TestOS,
			TargetArch:   targets.TestArch64,
			TargetVMArch: targets.TestArch64,
			Timeouts:     targets.Timeouts{Slowdown: 1},
		},
	}
}

func TestNew(t *testing.T) {
	defaultCfg := getTestDefaultCfg()

	nilServer := func(s Server) {
		assert.Nil(t, s)
	}

	tests := []struct {
		name              string
		modifyCfg         func() *mgrconfig.Config
		debug             bool
		expectedServCheck func(Server)
		expectsErr        bool
		expectedErr       error
	}{
		{
			name: "unknown Sandbox",
			modifyCfg: func() *mgrconfig.Config {
				cfg := defaultCfg
				cfg.Sandbox = "unknown"
				return &cfg
			},
			expectedServCheck: nilServer,
			expectsErr:        true,
		},
		{
			name: "experimental features",
			modifyCfg: func() *mgrconfig.Config {
				cfg := defaultCfg
				cfg.Experimental = mgrconfig.Experimental{
					RemoteCover: false,
					CoverEdges:  true,
				}
				return &cfg
			},
			expectedServCheck: func(srv Server) {
				s := srv.(*server)
				assert.Equal(t, s.cfg.Config.Features, flatrpc.AllFeatures&(^flatrpc.FeatureExtraCoverage))
				assert.Nil(t, s.serv)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := tt.modifyCfg()

			var err error
			cfg.Target, err = prog.GetTarget(cfg.TargetOS, cfg.TargetArch)
			assert.NoError(t, err)

			serv, err := New(&RemoteConfig{
				Config: cfg,
				Stats:  NewStats(),
				Debug:  tt.debug,
			})
			if tt.expectedErr != nil {
				assert.Equal(t, tt.expectedErr, err)
			} else if tt.expectsErr {
				assert.Error(t, err)
			} else {
				assert.Nil(t, err)
			}
			tt.expectedServCheck(serv)
		})
	}
}

func TestCheckRevisions(t *testing.T) {
	tests := []struct {
		name    string
		req     *flatrpc.ConnectRequest
		target  *prog.Target
		noError bool
	}{
		{
			name: "error - different Arch",
			req: &flatrpc.ConnectRequest{
				Arch: "arch",
			},
			target: &prog.Target{
				Arch: "arch2",
			},
		},
		{
			name: "error - different GitRevision",
			req: &flatrpc.ConnectRequest{
				Arch:        "arch",
				GitRevision: "different",
			},
			target: &prog.Target{
				Arch: "arch",
			},
		},
		{
			name: "error - different SyzRevision",
			req: &flatrpc.ConnectRequest{
				Arch:        "arch",
				GitRevision: prog.GitRevision,
				SyzRevision: "1",
			},
			target: &prog.Target{
				Arch:     "arch",
				Revision: "2",
			},
		},
		{
			name: "ok",
			req: &flatrpc.ConnectRequest{
				Arch:        "arch",
				GitRevision: prog.GitRevision,
				SyzRevision: "1",
			},
			target: &prog.Target{
				Arch:     "arch",
				Revision: "1",
			},
			noError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := checkRevisions(tt.req, tt.target)
			if tt.noError {
				assert.NoError(t, err)
			} else {
				assert.Error(t, err)
			}
		})
	}
}

func TestHandleConn(t *testing.T) {
	inConn, outConn := net.Pipe()
	serverConn := flatrpc.NewConn(inConn)
	clientConn := flatrpc.NewConn(outConn)

	managerMock := mocks.NewManager(t)
	debug := false
	defaultCfg := getTestDefaultCfg()

	tests := []struct {
		name       string
		wantErrMsg string
		modifyCfg  func() *mgrconfig.Config
		req        *flatrpc.ConnectRequest
	}{
		{
			name:       "error, cfg.VMLess = false - unknown VM tries to connect",
			wantErrMsg: "tries to connect",
			modifyCfg: func() *mgrconfig.Config {
				return &defaultCfg
			},
			req: &flatrpc.ConnectRequest{
				Id:          2, // Valid Runner id is 1.
				Arch:        "64",
				GitRevision: prog.GitRevision,
				SyzRevision: "1",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := tt.modifyCfg()

			var err error
			cfg.Target, err = prog.GetTarget(cfg.TargetOS, cfg.TargetArch)
			cfg.Target.Revision = tt.req.SyzRevision
			assert.NoError(t, err)

			s, err := New(&RemoteConfig{
				Config:  cfg,
				Manager: managerMock,
				Stats:   NewStats(),
				Debug:   debug,
			})
			assert.NoError(t, err)
			serv := s.(*server)

			injectExec := make(chan bool)
			serv.CreateInstance(1, injectExec, nil)
			g := errgroup.Group{}
			g.Go(func() error {
				hello, err := flatrpc.Recv[*flatrpc.ConnectHelloRaw](clientConn)
				if err != nil {
					return err
				}
				tt.req.Cookie = authHash(hello.Cookie)
				flatrpc.Send(clientConn, tt.req)
				return nil
			})
			if err := serv.handleConn(context.Background(), serverConn); err != nil {
				if !strings.Contains(err.Error(), tt.wantErrMsg) {
					t.Fatal(err)
				}
			}
			if err := g.Wait(); err != nil {
				t.Fatal(err)
			}
		})
	}
}

func TestMachineCheckCrash(t *testing.T) {
	target, err := prog.GetTarget(targets.TestOS, targets.TestArch64Fuzz)
	if err != nil {
		t.Fatal(err)
	}
	sysTarget := targets.Get(target.OS, target.Arch)
	if sysTarget.BrokenCompiler != "" {
		t.Skipf("skipping, broken cross-compiler: %v", sysTarget.BrokenCompiler)
	}
	executor := csource.BuildExecutor(t, target, "../..")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	checkBegan := make(chan struct{})
	cfg := &LocalConfig{
		Config: Config{
			Config: vminfo.Config{
				Target:   target,
				Features: flatrpc.FeatureSandboxNone,
				Sandbox:  flatrpc.ExecEnvSandboxNone,
			},
			Procs:               4,
			Slowdown:            1,
			machineCheckStarted: checkBegan,
		},
		Executor: executor,
		Dir:      t.TempDir(),
	}
	cfg.MachineChecked = func(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
		cancel()
		return queue.Callback(func() *queue.Request {
			return nil
		})
	}

	local, ctx, err := setupLocal(ctx, cfg)
	if err != nil {
		t.Fatal(err)
	}
	loopDone := make(chan error)
	go func() {
		loopDone <- local.Serve(ctx)
	}()

	t.Logf("starting the first instance")
	firstCtx, firstCancel := context.WithCancel(ctx)
	firstCh := make(chan error)
	go func() {
		firstCh <- local.RunInstance(firstCtx, 0)
	}()

	t.Logf("wait for the machine check to begin")
	<-checkBegan

	t.Logf("kill the first instance")
	firstCancel()
	if err := <-firstCh; err != nil {
		t.Fatal(err)
	}

	t.Logf("restart the instance")
	secondCh := make(chan error)
	go func() {
		secondCh <- local.RunInstance(ctx, 0)
	}()

	t.Logf("await the completion")
	if err := <-loopDone; err != nil {
		t.Fatal(err)
	}
	if err := <-secondCh; err != nil {
		t.Fatal(err)
	}
}
